-
Notifications
You must be signed in to change notification settings - Fork 24.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ML] Natural Language Processing tasks and models #73523
[ML] Natural Language Processing tasks and models #73523
Conversation
Pinging @elastic/ml-core (Team:ML) |
run elasticsearch-ci/part-1 |
jenkins test this please |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Just a couple of test related comments.
import org.elasticsearch.test.ESTestCase; | ||
|
||
|
||
public class TaskTypeTests extends ESTestCase { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one is left empty. Should we add some tests here?
import java.util.List; | ||
import java.util.stream.Collectors; | ||
|
||
public class FillMaskProcessor implements NlpTask.Processor { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add some tests for this one
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM Just a question about the name of the fill mask results field. Good to merge though even if you decide to change that.
@@ -26,25 +26,25 @@ | |||
public static final String NAME = "fill_mask_result"; | |||
public static final String DEFAULT_RESULTS_FIELD = "results"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this also be predictions
?
@@ -27,40 +26,49 @@ | |||
this.bertRequestBuilder = new BertRequestBuilder(tokenizer); | |||
} | |||
|
|||
@Override | |||
public void validateInputs(String inputs) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
The feature branch contains changes to configure PyTorch models with a TrainedModelConfig and defines a format to store the binary models. The _start and _stop deployment actions control the model lifecycle and the model can be directly evaluated with the _infer endpoint. 2 Types of NLP tasks are supported: Named Entity Recognition and Fill Mask. The feature branch consists of these PRs: #73523, #72218, #71679 #71323, #71035, #71177, #70713
Following on from #72218 which defined how large PyTorch models can be stored, this PR introduces the concepts of Natural Language Processing tasks and defines a way to evaluate BERT models.
Mask Fill and Named Entity Recognition tasks are implemented here but others could be easily added now the framework is in place. In particular this PR implements tokenisation of input text for BERT models and defines a structure for post-graph processing.
Once the PyTorch model is uploaded a trained model config referencing it must be PUT
And the model deployed:
Mask Fill Example
Returns
NER Example
Returns:
Feature branch PR
Co-authored-by: Dimitris Athanasiou [email protected]